#include "direct_method.h"
#include <chrono>
#include <iostream>
#include <boost/format.hpp>

namespace MultiColSLAM
{
    using namespace std;
    // double fx = 718.856, fy = 718.856, cx = 607.1928, cy = 185.2157;
    // lafida 相机参数
    // double fx = 197.79890327251547, fy = 198.2113151444225, cx = 369.6758863001525, cy = 219.93224923014523;
    // 309 相机参数
    double fx = 348.85011420800447, fy = 350.1616195267597, cx = 334.95775905376985, cy = 247.21215908355472;
    double baseline = 0.573;
    // double baseline = 0.3;
    string left_file = "../left.png";
    string disparity_file = "../disparity.png";
    boost::format fmt_others("../%06d.png");

    JacobianAccumulator::JacobianAccumulator(const cv::Mat &img1_, const cv::Mat &img2_,
                                             const VecVector2d &px_ref_, const std::vector<double> depth_ref_,
                                             Eigen::Isometry3d &T21_) : img1(img1_), img2(img2_), px_ref(px_ref_), depth_ref(depth_ref_), T21(T21_)
    {
        projection = VecVector2d(px_ref.size(), Eigen::Vector2d(0, 0));
    }

    void JacobianAccumulator::accumulate_jacobian(const cv::Range &range)
    {
        // parameters
        const int half_patch_size = 1;
        int cnt_good = 0;
        Matrix6d hessian = Matrix6d::Zero();
        Vector6d bias = Vector6d::Zero();
        double cost_tmp = 0;

        for (size_t i = range.start; i < range.end; i++)
        {

            // compute the projection in the second image
            Eigen::Vector3d point_ref =
                depth_ref[i] * Eigen::Vector3d((px_ref[i][0] - cx) / fx, (px_ref[i][1] - cy) / fy, 1);
            Eigen::Vector3d point_cur = T21 * point_ref;
            if (point_cur[2] < 0) // depth invalid
                continue;

            float u = fx * point_cur[0] / point_cur[2] + cx, v = fy * point_cur[1] / point_cur[2] + cy;
            if (u < half_patch_size || u > img2.cols - half_patch_size || v < half_patch_size ||
                v > img2.rows - half_patch_size)
                continue;

            projection[i] = Eigen::Vector2d(u, v);
            double X = point_cur[0], Y = point_cur[1], Z = point_cur[2],
                   Z2 = Z * Z, Z_inv = 1.0 / Z, Z2_inv = Z_inv * Z_inv;
            cnt_good++;

            // and compute error and jacobian
            for (int x = -half_patch_size; x <= half_patch_size; x++)
                for (int y = -half_patch_size; y <= half_patch_size; y++)
                {

                    double error = GetPixelValue(img1, px_ref[i][0] + x, px_ref[i][1] + y) -
                                   GetPixelValue(img2, u + x, v + y);
                    Matrix26d J_pixel_xi;
                    Eigen::Vector2d J_img_pixel;

                    J_pixel_xi(0, 0) = fx * Z_inv;
                    J_pixel_xi(0, 1) = 0;
                    J_pixel_xi(0, 2) = -fx * X * Z2_inv;
                    J_pixel_xi(0, 3) = -fx * X * Y * Z2_inv;
                    J_pixel_xi(0, 4) = fx + fx * X * X * Z2_inv;
                    J_pixel_xi(0, 5) = -fx * Y * Z_inv;

                    J_pixel_xi(1, 0) = 0;
                    J_pixel_xi(1, 1) = fy * Z_inv;
                    J_pixel_xi(1, 2) = -fy * Y * Z2_inv;
                    J_pixel_xi(1, 3) = -fy - fy * Y * Y * Z2_inv;
                    J_pixel_xi(1, 4) = fy * X * Y * Z2_inv;
                    J_pixel_xi(1, 5) = fy * X * Z_inv;

                    J_img_pixel = Eigen::Vector2d(
                        0.5 * (GetPixelValue(img2, u + 1 + x, v + y) - GetPixelValue(img2, u - 1 + x, v + y)),
                        0.5 * (GetPixelValue(img2, u + x, v + 1 + y) - GetPixelValue(img2, u + x, v - 1 + y)));

                    // total jacobian
                    Vector6d J = -1.0 * (J_img_pixel.transpose() * J_pixel_xi).transpose();

                    hessian += J * J.transpose();
                    bias += -error * J;
                    cost_tmp += error * error;
                }
        }

        if (cnt_good)
        {
            // set hessian, bias and cost
            unique_lock<mutex> lck(hessian_mutex);
            H += hessian;
            b += bias;
            cost += cost_tmp / cnt_good;
        }
    }

    Matrix6d JacobianAccumulator::hessian() const { return H; }
    Vector6d JacobianAccumulator::bias() const { return b; }
    double JacobianAccumulator::cost_func() const { return cost; }
    VecVector2d JacobianAccumulator::projected_points() const { return projection; }

    void JacobianAccumulator::reset()
    {
        H = Matrix6d::Zero();
        b = Vector6d::Zero();
        cost = 0;
    }

    DirectMethod::DirectMethod() {}

    double DirectMethod::DirectPoseEstimationSingleLayer(const cv::Mat &img1, const cv::Mat &img2,
                                                       const VecVector2d &px_ref, const std::vector<double> depth_ref,
                                                       Eigen::Isometry3d &T21)
    {
        const int iterations = 10;
        double cost = 0, lastCost = 0;
        auto t1 = chrono::steady_clock::now();
        JacobianAccumulator jaco_accu(img1, img2, px_ref, depth_ref, T21);

        for (int iter = 0; iter < iterations; iter++)
        {   
            jaco_accu.reset();
            cv::parallel_for_(cv::Range(0, px_ref.size()), JacobianAccumulatorLoopBody(&jaco_accu));            
            Matrix6d H = jaco_accu.hessian();
            Vector6d b = jaco_accu.bias();

            // Check if H is singular
            if (H.determinant() < 1e-10) {
                std::cerr << "Hessian matrix is nearly singular, stopping iteration" << std::endl;
                break;
            }

            // solve update and put it into estimation
            // Vector6d update = H.ldlt().solve(b);
            Eigen::Matrix<double, 6, 1> update = H.ldlt().solve(b);

            // Check for NaN in update
            if (std::isnan(update[0])) {
                std::cerr << "Update contains NaN values, stopping iteration" << std::endl;
                break;
            }

            // Normalize rotation update
            double rotation_norm = update.tail<3>().norm();
            if (rotation_norm < 1e-10) {
                std::cerr << "Rotation norm too small, skipping iteration" << std::endl;
                break;
            }

        
            Eigen::Matrix3d dR = Eigen::AngleAxisd(update.tail<3>().norm(), update.tail<3>().normalized()).toRotationMatrix();
            Eigen::Vector3d dt = update.head<3>();
            Eigen::Isometry3d transform_update = Eigen::Isometry3d::Identity();
            transform_update.rotate(dR);
            transform_update.pretranslate(dt);
            T21 = transform_update * T21;

            cout << T21.matrix() << endl;
            

            cost = jaco_accu.cost_func();

            if (std::isnan(update[0]))
            {
                // sometimes occurred when we have a black or white patch and H is irreversible
                // cout << "update is nan" << endl;
                break;
            }
            if (iter > 0 && cost > lastCost)
            {
                // cout << "cost increased: " << cost << ", " << lastCost << endl;
                break;
            }
            if (update.norm() < 1e-3)
            {
                // converge
                break;
            }

            lastCost = cost;
            // cout << "iteration: " << iter << ", cost: " << cost << endl;
        }

        // cout << "T21 = \n"
        //      << T21.matrix() << endl;
        // auto t2 = chrono::steady_clock::now();
        // auto time_used = chrono::duration_cast<chrono::duration<double>>(t2 - t1);
        // cout << "direct method for single layer: " << time_used.count() << endl;

        // plot the projected pixels here
        // cv::Mat img2_show;
        // cv::cvtColor(img2, img2_show, cv::COLOR_GRAY2BGR);
        // VecVector2d projection = jaco_accu.projected_points();
        // for (size_t i = 0; i < px_ref.size(); ++i)
        // {
        //     auto p_ref = px_ref[i];
        //     auto p_cur = projection[i];
        //     if (p_cur[0] > 0 && p_cur[1] > 0)
        //     {
        //         cv::circle(img2, cv::Point2f(p_cur[0], p_cur[1]), 2, cv::Scalar(0, 250, 0), 2);
        //         cv::line(img2, cv::Point2f(p_ref[0], p_ref[1]), cv::Point2f(p_cur[0], p_cur[1]),
        //                  cv::Scalar(0, 250, 0));
        //     }
        // }
        // cv::imshow("current", img2);
        // cv::waitKey();
        return cost;
    }

    double DirectMethod::DirectPoseEstimationMultiLayer(const cv::Mat &img1, const cv::Mat &img2,
                                                      const VecVector2d &px_ref, const std::vector<double> depth_ref,
                                                      Eigen::Isometry3d &T21)
    {
        // parameters
        int pyramids = 4;
        double pyramid_scale = 0.5;
        double scales[] = {1.0, 0.5, 0.25, 0.125};
        double total_cost = 0.0;

        // create pyramids
        vector<cv::Mat> pyr1, pyr2; // image pyramids
        for (int i = 0; i < pyramids; i++)
        {
            if (i == 0)
            {
                pyr1.push_back(img1);
                pyr2.push_back(img2);
            }
            else
            {
                cv::Mat img1_pyr, img2_pyr;
                cv::resize(pyr1[i - 1], img1_pyr,
                           cv::Size(pyr1[i - 1].cols * pyramid_scale, pyr1[i - 1].rows * pyramid_scale));
                cv::resize(pyr2[i - 1], img2_pyr,
                           cv::Size(pyr2[i - 1].cols * pyramid_scale, pyr2[i - 1].rows * pyramid_scale));
                pyr1.push_back(img1_pyr);
                pyr2.push_back(img2_pyr);
            }
        }

        double fxG = fx, fyG = fy, cxG = cx, cyG = cy; // backup the old values
        for (int level = pyramids - 1; level >= 0; level--)
        {
            VecVector2d px_ref_pyr; // set the keypoints in this pyramid level
            for (auto &px : px_ref)
            {
                px_ref_pyr.push_back(scales[level] * px);
            }

            // scale fx, fy, cx, cy in different pyramid levels
            fx = fxG * scales[level];
            fy = fyG * scales[level];
            cx = cxG * scales[level];
            cy = cyG * scales[level];
            total_cost += DirectPoseEstimationSingleLayer(pyr1[level], pyr2[level], px_ref_pyr, depth_ref, T21);
        }
        return total_cost;
    }

    float GetPixelValue(const cv::Mat &img, float x, float y)
    {
        // boundary check
        if (x < 0)
            x = 0;
        if (y < 0)
            y = 0;
        if (x >= img.cols)
            x = img.cols - 1;
        if (y >= img.rows)
            y = img.rows - 1;
        uchar *data = &img.data[int(y) * img.step + int(x)];
        float xx = x - floor(x);
        float yy = y - floor(y);
        return float(
            (1 - xx) * (1 - yy) * data[0] +
            xx * (1 - yy) * data[1] +
            (1 - xx) * yy * data[img.step] +
            xx * yy * data[img.step + 1]);
    }

}